{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n",
      "loading BERT model... "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight']\n",
      "- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "done\n",
      "loading CLIP model... done\n",
      "DEVICE:  cuda\n",
      "loading dataset:  forces_only_f ...done\n",
      "raw X: (10000, 953) \tY: (10000, 160)\n",
      "filtered X: (10000, 953) \tY: (10000, 160)\n",
      "Train X: (7000, 953) \tY: (7000, 160)\n",
      "Test  X: (2000, 953) \tY: (2000, 160)\n",
      "Val   X: (1000, 953) \tY: (1000, 160)\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' \n",
    "import numpy as np\n",
    "from src.motion_refiner_4D import Motion_refiner, MAX_NUM_OBJS\n",
    "from src.config import *\n",
    "from src.functions import *\n",
    "\n",
    "mr = Motion_refiner(load_models=True ,traj_n = traj_n, locality_factor=True, clip_only=False)\n",
    "feature_indices, obj_sim_indices, obj_poses_indices, traj_indices = mr.get_indices()\n",
    "embedding_indices = mr.embedding_indices\n",
    "\n",
    "#============================== load dataset ==========================================\n",
    "X,Y, data = mr.load_dataset(\"forces_only_f\", filter_data = True, base_path=data_folder)\n",
    "X_train, X_test, X_valid, y_train, y_test, y_valid, indices_train, indices_test, indices_val = mr.split_dataset(X, Y, test_size=0.2, val_size=0.1)\n",
    "\n",
    "\n",
    "def prepare_x(x):\n",
    "    objs = pad_array(list_to_wp_seq(x[:,obj_poses_indices],d=3),4,axis=-1) # no speed\n",
    "    trajs = list_to_wp_seq(x[:,traj_indices],d=4)\n",
    "    #   return np.concatenate([objs,trajs],axis = 1)\n",
    "    return trajs[:,:-1,:]\n",
    "\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices((prepare_x(X_test),\n",
    "                                                    list_to_wp_seq(y_test,d=4),\n",
    "                                                    X_test[:,embedding_indices])).batch(X_test.shape[0])\n",
    "\n",
    "g = generator(test_dataset,stop=True,augment=False)\n",
    "x_t, y_t = next(g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'num_layers_enc': 1, 'num_layers_dec': 5, 'd_model': 400, 'dff': 512, 'num_heads': 8, 'dropout_rate': 0.1, 'wp_d': 4, 'num_emb_vec': 16, 'bs': 16, 'dense_n': 512, 'num_dense': 3, 'concat_emb': True, 'features_n': 793, 'optimizer': 'adam', 'norm_layer': True, 'activation': 'tanh', 'loss': 'mse'}\n",
      "loading weights:  /home/arthur/local_data/models/forces_onl/TF-num_layers_enc:1-num_layers_dec:5-d_model:400-dff:512-num_heads:8-dropout_rate:0.1-wp_d:4-num_emb_vec:16-bs:16-dense_n:512-num_dense:3-concat_emb:True-features_n:793-optimizer:adam-norm_layer:True-activation:tanh-loss:mse.h5\n"
     ]
    }
   ],
   "source": [
    "from src.TF4D_decoder_only import *\n",
    "\n",
    "model_path = models_folder+\"forces_onl/\"\n",
    "model_name = \"TF-num_layers_enc:1-num_layers_dec:5-d_model:400-dff:512-num_heads:8-dropout_rate:0.1-wp_d:4-num_emb_vec:16-bs:16-dense_n:512-num_dense:3-concat_emb:True-features_n:793-optimizer:adam-norm_layer:True-activation:tanh-loss:mse.h5\"\n",
    "\n",
    "model_file = model_path+model_name\n",
    "model = load_model(model_file, delimiter=\"-\")\n",
    "\n",
    "# metrics = evaluate_model(model, x_t, y_t)\n",
    "# models_metrics[model_tag] = metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "63/63 [==============================] - 3s 17ms/step\n"
     ]
    }
   ],
   "source": [
    "pred = model.predict(x_t)\n",
    "pred_d = np.array(data)[indices_test]\n",
    "for i,d in enumerate(pred_d):\n",
    "    d[\"forces\"] = pred[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "findfont: Font family ['Times New Roman'] not found. Falling back to DejaVu Sans.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib qt\n",
    "indices = np.random.choice(range(len(indices_test)), 3)\n",
    "\n",
    "plt.close('all')\n",
    "# pred_t = np.transpose(pred[:,:,:2],[0,2,1])\n",
    "data_array = np.array(data)[indices_test[indices]]\n",
    "show_data4D(data_array,plot_output=False)\n",
    "# data_array = pred_d[indices]\n",
    "# show_data4D(data_array,plot_output=False)\n",
    "\n",
    "# show_data4D(data_array,plot_output=False,image_loader=mr.image_loader, color_traj=False, change_img_base=[\"/home/mirmi/Arthur/dataset/\",\"/home/tum/data/image_dataset/\"])\n",
    "# show_data4D(data_sample,plot_output=False)#,image_loader=mr.image_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cart9.txt\n",
      "---------------------------------------\n",
      "scoope a lot more down\n",
      "input (40, 4)\n",
      "DONE - computing textual embeddings (1, 768)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [00:00, 147.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DONE - computing similarity vectors \n",
      "DONE - concatenating \n",
      "X shape (1, 953)\n",
      "1/1 [==============================] - 0s 22ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%matplotlib qt\n",
    "\n",
    "def interpolate_traj(wps,traj_n=40, offset=[0,0,0,0]):\n",
    "    #create spline function\n",
    "    f, u = interpolate.splprep([wps[:,0],wps[:,1],wps[:,2]], s=0)\n",
    "    xint,yint,zint= interpolate.splev(np.linspace(0, 1, traj_n), f)\n",
    "\n",
    "    tck,u = interpolate.splprep([np.linspace(0,1,len(wps[:,3])), wps[:,3]])\n",
    "    velint_x, velint = interpolate.splev(np.linspace(0, 1, traj_n), tck)\n",
    "\n",
    "    traj = np.stack([xint,yint,zint,velint],axis=1)+offset\n",
    "    return traj\n",
    "\n",
    "\n",
    "def norm_traj_and_objs(t, o, margin=0.45):\n",
    "    pts_ = np.concatenate([o,t])\n",
    "\n",
    "    vel = pts_[:,3:]\n",
    "    pts = pts_[:,:3]\n",
    "\n",
    "    vel_min = np.min(vel,axis = 0)\n",
    "    vel_max = np.max(vel,axis = 0)\n",
    "    vel_norm = np.max(np.abs(vel_max-vel_min))\n",
    "    if vel_norm > 1e-10:\n",
    "        vel = ((vel-(vel_max-vel_min)/2)/vel_norm)*(1-margin)\n",
    "\n",
    "    else:\n",
    "        vel = vel-vel_min\n",
    "\n",
    "    pts_min = np.min(pts,axis = 0)\n",
    "    pts_max = np.max(pts,axis = 0)\n",
    "    pts_norm = np.max(np.abs(pts_max-pts_min))\n",
    "\n",
    "    # pts  = ((pts-pts_min)/pts_norm)*(1-margin)+margin/2-0.5\n",
    "    pts  = ((pts-(pts_max-pts_min)/2)/pts_norm)*(1-margin)\n",
    "\n",
    "    pts_new= np.concatenate([pts,vel],axis=-1)\n",
    "    o_new = pts_new[:o.shape[0],:]\n",
    "    t_new = pts_new[o.shape[0]:,:]\n",
    "\n",
    "    return t_new, o_new, [pts_norm, (pts_max-pts_min)/2,vel_norm, (vel_max-vel_min)/2, margin]\n",
    "\n",
    "def rescale(pts_, factor_list):\n",
    "\n",
    "    vel = pts_[:,3:]\n",
    "    pts = pts_[:,:3]\n",
    "\n",
    "    pts_norm, pts_avr,vel_norm, vel_avr, margin = factor_list\n",
    "    pts = pts/(1-margin)*pts_norm+pts_avr\n",
    "    vel = vel/(1-margin)*vel_norm+vel_avr\n",
    "    pts_new= np.concatenate([pts,vel],axis=-1)\n",
    "\n",
    "    return pts_new\n",
    "\n",
    "def modify_traj(mr, traj, obj_poses, text, obj_names,obj_poses_offset, show=False):\n",
    "\n",
    "    images = None\n",
    "\n",
    "    print(\"---------------------------------------\")\n",
    "    print(text)\n",
    "    # print(traj.shape, obj_poses.shape, text, obj_names,obj_poses_offset.shape)\n",
    "\n",
    "    traj_, obj_poses_, factor_list = norm_traj_and_objs(traj, obj_poses)\n",
    "    traj_int = interpolate_traj(traj_,traj_n=traj_n)\n",
    "    \n",
    "    obj_poses_ = obj_poses_[:,:3]\n",
    "\n",
    "    d = np2data(traj_int, obj_names, obj_poses_, text)[0]\n",
    "\n",
    "    pred, traj_in = mr.apply_interaction(model, d, text,  label=False, images=images, dec_only=True)\n",
    "\n",
    "    # %matplotlib qt\n",
    "    if show:\n",
    "        data_array = np.array([d])\n",
    "        # show_data4D(data_array, forces=pred, color_traj=False)\n",
    "        # print(traj_in.shape)\n",
    "        objs  = {}\n",
    "        obj_names = np.asarray(d[\"obj_names\"])\n",
    "        obj_pt = np.asarray(d[\"obj_poses\"])\n",
    "        for x,y,z,name in zip(obj_pt[:,0],obj_pt[:,1],obj_pt[:,2],obj_names):\n",
    "            objs[name] = {\"value\":{\"obj_p\":[x,y,z]}}\n",
    "        \n",
    "        # plot_samples(text,traj_[-2000:-1000],[], images=[], fig=None,objs=objs, colors = [\"#02b322\",\"#0c45c2\",\"#C2450C\"],alpha=[0.9,0.9],linewidth=1.0,\n",
    "        #         plot_voxels= False, color_traj = False, map_cost_f=None, labels=[], plot_speed=False, show=True, forces=None)\n",
    "\n",
    "        plot_samples(text,traj_in[0],[], images=[], fig=None,objs=objs, colors = [\"#02b322\",\"#0c45c2\",\"#C2450C\"],alpha=[0.9,0.9],linewidth=1.0,\n",
    "                plot_voxels= False, color_traj = False, map_cost_f=None, labels=[], plot_speed=False, show=True, forces=pred[0])\n",
    "    # \n",
    "    traj_new = rescale(pred[0], factor_list)\n",
    "    \n",
    "    traj_new = interpolate_traj(traj_new,traj_n=np.array(traj).shape[0])\n",
    "    # publish_simple_traj(traj_new,obj_poses+obj_poses_offset,new_traj_pub)\n",
    "    # publish_simple_traj(traj,obj_poses+obj_poses_offset,traj_pub)\n",
    "\n",
    "    return traj_new,d\n",
    "\n",
    "from src.motion_refiner_4D import Motion_refiner, MAX_NUM_OBJS\n",
    "\n",
    "forces_data_folder = \"/home/arthur/local_data/forces_data/\"\n",
    "forces_data_folder_new = \"/home/arthur/local_data/forces_data_new/\"\n",
    "text = \"scoope deeper\"\n",
    "text = \"scoope a lot more down\"\n",
    "# text = \"press a lot stronger down\"\n",
    "# text = \"press stronger down\"\n",
    "\n",
    "# text = \"scoope much deeper down\"\n",
    "\n",
    "obj_poses = np.array([[0,0,-0.5,0],[0,0.4,-0.5,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]])\n",
    "obj_poses_offset = np.array([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]])\n",
    "obj_names = [\"sand\",\"table\"]\n",
    "plt.close()\n",
    "for force_file in os.listdir(forces_data_folder)[2:]:\n",
    "    if not \"9\" in force_file:\n",
    "        continue\n",
    "    # try:\n",
    "    print(force_file)\n",
    "    force_file_file = os.path.join(forces_data_folder,force_file)\n",
    "    traj=[]\n",
    "    with open(force_file_file) as f:\n",
    "        lines = f.readlines()\n",
    "    for i,l in enumerate(lines):\n",
    "        traj.append([float(c) for c in l[:-1].split(\",\")]+[0])\n",
    "\n",
    "    traj_new,d = modify_traj(mr, traj, obj_poses, text, obj_names,obj_poses_offset, show=True)\n",
    "    \n",
    "    force_file_file_new = os.path.join(forces_data_folder_new,force_file)\n",
    "    with open(force_file_file_new, 'w') as f:\n",
    "        for wp in traj_new:\n",
    "            line = str(wp[0])+\" ,\"+str(wp[1])+\" ,\"+str(wp[2])\n",
    "            f.write(line)\n",
    "            f.write('\\n')\n",
    "\n",
    "    # except:\n",
    "    #     pass\n",
    "    # plt.close() \n",
    "    # fig = plt.figure(figsize=(10,13))\n",
    "    # plt.plot(range(len(traj_new[:,0])),traj_new[:,2])\n",
    "    # plt.plot(range(len(traj_new[:,0])),traj_new[:,1])\n",
    "    # traj\n",
    "    break\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('py38_cu11_2')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5de91200c9f9e1f8a0c28ceba668014be0fd55838e84400e0a7ad1d269192773"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
